{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "W = np.random.random((9,1))\n",
    "alpha = 0.1\n",
    "gamma = 0.99\n",
    "state_encode = {\n",
    "    1:[2,0,0,0,0,0,0,1],\n",
    "    2:[0,2,0,0,0,0,0,1],\n",
    "    3:[0,0,2,0,0,0,0,1],\n",
    "    4:[0,0,0,2,0,0,0,1],\n",
    "    5:[0,0,0,0,2,0,0,1],\n",
    "    6:[0,0,0,0,0,2,0,1],\n",
    "    7:[0,0,0,0,0,0,1,2],\n",
    "}\n",
    "SA_dict = {}\n",
    "def SA(state, action):\n",
    "    \"\"\"Return the feature column vector for a (state, action) pair.\n",
    "\n",
    "    The vector is the state's entry in ``state_encode`` followed by the\n",
    "    action value, reshaped to a single column. Results are cached in\n",
    "    ``SA_dict``, so repeated calls return the same array object.\n",
    "\n",
    "    Raises KeyError if ``state`` is not a key of ``state_encode``.\n",
    "    \"\"\"\n",
    "    # SA_dict is only mutated, never rebound, so no `global` statement is needed.\n",
    "    key = (state, action)\n",
    "    if key not in SA_dict:\n",
    "        feature = np.array(state_encode[state] + [action])\n",
    "        SA_dict[key] = np.reshape(feature, (-1,1))\n",
    "    return SA_dict[key]\n",
    "\n",
    "def Q(state, action):\n",
    "    \"\"\"Linear action-value estimate: W transposed times the features of (state, action).\n",
    "\n",
    "    Action coding: 0 = dashed, 1 = solid.\n",
    "    \"\"\"\n",
    "    features = SA(state, action)\n",
    "    return W.T @ features\n",
    "\n",
    "def b():\n",
    "    dice = np.random.random()\n",
    "    # 0 means dashed, 1 means solid\n",
    "    if dice > 1/7:\n",
    "        return 0, np.random.randint(low=1, high=7)\n",
    "    else:\n",
    "        return 1, 7\n",
    "\n",
    "def update(state, next_state, action):\n",
    "    \"\"\"Apply one update to the global weights W for a single transition.\n",
    "\n",
    "    The reward is taken to be 0, so the error term is\n",
    "    gamma * max(Q(next_state, 0), Q(next_state, 1)) - Q(state, action),\n",
    "    and W moves by alpha * error * SA(state, action).\n",
    "    \"\"\"\n",
    "    global W\n",
    "    # Greedy value of the successor state over the two actions (0 and 1).\n",
    "    greedy_next = max(Q(next_state, 0), Q(next_state, 1))\n",
    "    td_error = gamma * greedy_next - Q(state, action)\n",
    "    # td_error is a 1x1 array; it broadcasts against the feature column.\n",
    "    W = W + alpha * td_error * SA(state, action)\n",
    "\n",
    "def game(episodes=30, steps=9999, threshold=1e10):\n",
    "    \"\"\"Run independent episodes and report the step at which the weights diverge.\n",
    "\n",
    "    Each episode re-draws the global W at random, samples a start state from\n",
    "    b(), and then applies update() for up to ``steps`` sampled transitions.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    episodes : int\n",
    "        Number of episodes to run (numbered from 1 in the printed output).\n",
    "    steps : int\n",
    "        Maximum number of updates per episode.\n",
    "    threshold : float\n",
    "        An episode is reported as diverged once any weight's magnitude\n",
    "        reaches this value, or any weight becomes non-finite.\n",
    "    \"\"\"\n",
    "    global W\n",
    "    for episode in range(1, episodes + 1):\n",
    "        W = np.random.random((9,1))\n",
    "        _, state = b()\n",
    "        for t in range(1, steps + 1):\n",
    "            action, next_state = b()\n",
    "            update(state, next_state, action)\n",
    "            state = next_state\n",
    "            # Check magnitude, not just the largest signed value: weights can\n",
    "            # blow up towards -inf (or become NaN), which W.max() alone misses.\n",
    "            if not np.all(np.isfinite(W)) or np.abs(W).max() >= threshold:\n",
    "                print('Episode',episode,'diverged at step', t)\n",
    "                break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Episode 1 diverged at step 1730\n",
      "Episode 2 diverged at step 2161\n",
      "Episode 5 diverged at step 1778\n",
      "Episode 7 diverged at step 1782\n",
      "Episode 8 diverged at step 1701\n",
      "Episode 9 diverged at step 1644\n",
      "Episode 10 diverged at step 1718\n",
      "Episode 13 diverged at step 1916\n",
      "Episode 14 diverged at step 1770\n",
      "Episode 15 diverged at step 1663\n",
      "Episode 16 diverged at step 1661\n",
      "Episode 18 diverged at step 1769\n",
      "Episode 19 diverged at step 1810\n",
      "Episode 21 diverged at step 1780\n",
      "Episode 22 diverged at step 1843\n",
      "Episode 23 diverged at step 1834\n",
      "Episode 24 diverged at step 1735\n",
      "Episode 25 diverged at step 1740\n",
      "Episode 26 diverged at step 2003\n",
      "Episode 27 diverged at step 1775\n",
      "Episode 28 diverged at step 1660\n",
      "Episode 29 diverged at step 1907\n",
      "Episode 30 diverged at step 1711\n"
     ]
    }
   ],
   "source": [
    "# Run the experiment; game() prints one line per episode it detects as diverging.\n",
    "game()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  },
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}